# orbis-backend/src/api/middleware/firebase_auth.py
# Deusxx1234's picture
# feat: granular agent sub-steps, immediate WS push, all backend fixes
# caa768d
"""Firebase JWT verification middleware."""
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from src.utils.database import AsyncSessionLocal
from src.models.database import User
from src.config import get_settings
import httpx
import jwt as pyjwt
import os
import time
# Bearer-token extractor used as a FastAPI dependency. auto_error=False is
# passed so the dependency itself does not raise; verify_firebase_token
# handles the "no credentials" case explicitly.
bearer_scheme = HTTPBearer(auto_error=False)

# JWK set endpoint for the securetoken service account.
JWKS_URL = (
    "https://www.googleapis.com/service_accounts/v1/jwk/"
    "securetoken@system.gserviceaccount.com"
)

# Process-wide cache of the last fetched JWK set.
#   "expires_at": UNIX timestamp after which the entry is considered stale
#   "keys":       the parsed JWKS document, or None before the first fetch
_PUBLIC_KEYS_CACHE = {"expires_at": 0.0, "keys": None}
def _firebase_project_id() -> str:
    """Return the configured Firebase project id.

    The ``FIREBASE_PROJECT_ID`` environment variable wins when it is set to a
    non-empty value; otherwise the value from ``get_settings()`` is returned.
    """
    from_env = os.getenv("FIREBASE_PROJECT_ID", "")
    if from_env:
        return from_env
    # Settings are consulted only when the environment variable is unset/empty.
    return get_settings().firebase_project_id
async def _fetch_jwks(verify) -> dict:
    """GET ``JWKS_URL`` and return the parsed JSON body.

    ``verify`` is forwarded unchanged to ``httpx.AsyncClient`` (a CA bundle
    path, or ``False`` to disable certificate verification).
    """
    async with httpx.AsyncClient(verify=verify, timeout=10) as client:
        response = await client.get(JWKS_URL)
        response.raise_for_status()
        return response.json()


async def _get_firebase_public_keys() -> dict:
    """Return the JWK set used to verify Firebase tokens, cached for one hour.

    The keys are fetched over TLS verified against the certifi CA bundle.

    Security: these keys decide which token signatures are trusted, so
    fetching them without certificate verification lets anyone who can
    intercept the connection supply their own keys and forge tokens. The
    unverified retry is therefore only attempted when the environment
    variable ``FIREBASE_JWKS_ALLOW_INSECURE_SSL`` is set to ``1``, ``true``
    or ``yes`` (intended for local development machines where antivirus
    software interferes with the certificate chain). Never set it in
    production.

    Returns:
        The parsed JWKS document (a dict expected to carry a ``"keys"`` list).

    Raises:
        Exception: whatever httpx raised, when the fetch fails and the
            insecure retry is not applicable or not enabled.
    """
    now = time.time()
    if _PUBLIC_KEYS_CACHE["keys"] and _PUBLIC_KEYS_CACHE["expires_at"] > now:
        return _PUBLIC_KEYS_CACHE["keys"]

    import certifi

    try:
        data = await _fetch_jwks(certifi.where())
    except Exception as e:
        msg = str(e).lower()
        # Only certificate/SSL failures are candidates for the dev fallback;
        # every other error propagates unchanged.
        if "certificate" not in msg and "ssl" not in msg:
            raise

        from loguru import logger

        allow_insecure = os.getenv(
            "FIREBASE_JWKS_ALLOW_INSECURE_SSL", ""
        ).strip().lower() in {"1", "true", "yes"}
        if not allow_insecure:
            logger.warning(
                f"[AUTH] SSL verify failed fetching Firebase keys: {e}. "
                "Set FIREBASE_JWKS_ALLOW_INSECURE_SSL=1 to retry without "
                "verification (local development only)."
            )
            raise

        logger.warning(f"[AUTH] SSL verify failed locally, retrying without verify: {e}")
        data = await _fetch_jwks(False)

    _PUBLIC_KEYS_CACHE["keys"] = data
    _PUBLIC_KEYS_CACHE["expires_at"] = now + 3600
    return data
async def verify_firebase_token_string(token: str) -> dict:
    """Verify a raw Firebase JWT and return its decoded payload.

    The token's ``kid`` header selects one of the cached public keys; the
    signature (RS256), expiry, audience (the Firebase project id) and issuer
    (``https://securetoken.google.com/<project id>``) are then checked by
    PyJWT.

    Args:
        token: The encoded JWT as sent by the client.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 when no Firebase project id is configured;
            401 when the token is expired, signed by an unknown key, or
            fails any other verification step (including failures to fetch
            the public keys).
    """
    project_id = _firebase_project_id()
    if not project_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Firebase project is not configured",
        )

    try:
        header = pyjwt.get_unverified_header(token)
        keys_data = await _get_firebase_public_keys()

        kid = header.get("kid")
        jwks = keys_data.get("keys", [])
        matching_jwk = next((k for k in jwks if k.get("kid") == kid), None)
        if matching_jwk is None:
            print(f"[AUTH] No matching key for kid={kid}; available={[k.get('kid') for k in jwks]}")
            raise HTTPException(status_code=401, detail="Invalid token key")

        public_key = pyjwt.algorithms.RSAAlgorithm.from_jwk(matching_jwk)
        return pyjwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=project_id,
            # Firebase ID tokens must be issued by the securetoken service
            # for this exact project; without this check only `aud` ties the
            # token to the project.
            issuer=f"https://securetoken.google.com/{project_id}",
        )
    except pyjwt.ExpiredSignatureError:
        print("[AUTH] Token expired")
        raise HTTPException(status_code=401, detail="Token expired")
    except HTTPException:
        # Already a well-formed HTTP error (e.g. unknown kid) — pass through.
        raise
    except Exception as e:
        print(f"[AUTH] Verify failed: {type(e).__name__}: {e}; project_id={project_id}")
        raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}")
async def verify_firebase_token(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency: verify the request's bearer token.

    Returns the decoded token payload produced by
    ``verify_firebase_token_string``. Raises a 401 ``HTTPException`` with
    detail ``"Not authenticated"`` when no credentials were supplied.
    """
    if credentials:
        raw_token = credentials.credentials
        return await verify_firebase_token_string(raw_token)

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
    )
async def get_current_user(payload: dict = Depends(verify_firebase_token)) -> "User":
    """Get or create the ``User`` row matching a verified Firebase token.

    The user is looked up by ``firebase_uid`` (the token's ``uid`` claim,
    falling back to ``sub``). When no row exists, one is created from the
    token's ``email``, ``name`` and ``picture`` claims and committed.

    Args:
        payload: Decoded token claims supplied by ``verify_firebase_token``.

    Returns:
        The existing or newly created ``User``.

    Raises:
        HTTPException: 401 when the payload carries neither ``uid`` nor
            ``sub``, since there is then nothing to identify the user by.
    """
    firebase_uid = payload.get("uid") or payload.get("sub")
    if not firebase_uid:
        # Without this guard the lookup would match on NULL and a user with
        # firebase_uid=None could be inserted.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has no subject",
        )

    # Normalise a missing or explicit-null email claim to "" so the
    # display-name fallback below cannot fail on None.
    email = payload.get("email") or ""

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(User).where(User.firebase_uid == firebase_uid)
        )
        user = result.scalar_one_or_none()
        if not user:
            user = User(
                firebase_uid=firebase_uid,
                email=email,
                # Default display name is the local part of the email.
                display_name=payload.get("name", email.split("@")[0]),
                photo_url=payload.get("picture"),
                is_active=True,
            )
            session.add(user)
            await session.commit()
            # Reload so database-generated fields are populated on the object.
            await session.refresh(user)
        return user