File size: 4,444 Bytes
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caa768d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caa768d
c84fdae
 
 
 
 
 
 
 
 
caa768d
c84fdae
 
 
 
caa768d
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Firebase JWT verification middleware."""

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from src.utils.database import AsyncSessionLocal
from src.models.database import User
from src.config import get_settings
import httpx
import jwt as pyjwt
import os
import time

# Extracts the "Authorization: Bearer <token>" credentials. auto_error=False
# makes a missing header arrive as None instead of raising automatically, so
# verify_firebase_token can raise its own 401 "Not authenticated".
bearer_scheme = HTTPBearer(auto_error=False)
JWKS_URL = f"https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com"
_PUBLIC_KEYS_CACHE = {"expires_at": 0.0, "keys": None}


def _firebase_project_id() -> str:
    return os.getenv("FIREBASE_PROJECT_ID", "") or get_settings().firebase_project_id


async def _get_firebase_public_keys() -> dict:
    now = time.time()
    if _PUBLIC_KEYS_CACHE["keys"] and _PUBLIC_KEYS_CACHE["expires_at"] > now:
        return _PUBLIC_KEYS_CACHE["keys"]

    # Try certifi first (production-safe). Fall back to verify=False if SSL fails
    # (Windows dev machines often have antivirus interfering with cert chain).
    import certifi
    try:
        async with httpx.AsyncClient(verify=certifi.where(), timeout=10) as client:
            r = await client.get(JWKS_URL)
            r.raise_for_status()
            data = r.json()
    except Exception as e:
        msg = str(e).lower()
        if "certificate" in msg or "ssl" in msg:
            from loguru import logger
            logger.warning(f"[AUTH] SSL verify failed locally, retrying without verify: {e}")
            async with httpx.AsyncClient(verify=False, timeout=10) as client:
                r = await client.get(JWKS_URL)
                r.raise_for_status()
                data = r.json()
        else:
            raise

    _PUBLIC_KEYS_CACHE["keys"] = data
    _PUBLIC_KEYS_CACHE["expires_at"] = now + 3600
    return data


async def verify_firebase_token_string(token: str) -> dict:
    """Verify a raw Firebase ID token (JWT) and return its decoded payload.

    The token must meet all of these conditions, per Firebase's documented
    ID-token checks:
      - RS256-signed by one of Google's published keys, matched by header ``kid``;
      - unexpired;
      - ``aud == <project id>``;
      - ``iss == https://securetoken.google.com/<project id>``;
      - a non-empty ``sub``.

    Args:
        token: The encoded JWT, without any "Bearer " prefix.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 500 if no Firebase project id is configured; 503 if
            Google's signing keys cannot be fetched; 401 if the token is
            malformed, expired, signed with an unknown key, or fails a
            claim check.
    """
    from loguru import logger

    project_id = _firebase_project_id()
    if not project_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Firebase project is not configured",
        )

    try:
        header = pyjwt.get_unverified_header(token)
    except Exception as e:
        logger.warning(f"[AUTH] Malformed token header: {type(e).__name__}: {e}")
        raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") from e

    # Failing to reach Google is an outage on our side, not a bad token.
    # Report 503 so clients do not treat it as "log the user out".
    try:
        keys_data = await _get_firebase_public_keys()
    except Exception as e:
        logger.error(f"[AUTH] Could not fetch Firebase public keys: {type(e).__name__}: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication service unavailable",
        ) from e

    try:
        kid = header.get("kid")
        public_key = None
        for key in keys_data.get("keys", []):
            if key.get("kid") == kid:
                public_key = pyjwt.algorithms.RSAAlgorithm.from_jwk(key)
                break

        if not public_key:
            available = [k.get("kid") for k in keys_data.get("keys", [])]
            logger.warning(f"[AUTH] No matching key for kid={kid}; available={available}")
            raise HTTPException(status_code=401, detail="Invalid token key")

        payload = pyjwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=project_id,
            # aud alone is not the full Firebase contract; the issuer must
            # also be this project's securetoken endpoint.
            issuer=f"https://securetoken.google.com/{project_id}",
            options={"require": ["exp", "iat", "aud", "iss", "sub"]},
        )
    except pyjwt.ExpiredSignatureError as e:
        logger.info("[AUTH] Token expired")
        raise HTTPException(status_code=401, detail="Token expired") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.warning(f"[AUTH] Verify failed: {type(e).__name__}: {e}; project_id={project_id}")
        raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") from e

    # "require" only checks that the claim is present. Firebase also requires
    # sub (the user's uid) to be a non-empty string.
    subject = payload.get("sub")
    if not isinstance(subject, str) or not subject:
        raise HTTPException(status_code=401, detail="Invalid token: missing subject")
    return payload


async def verify_firebase_token(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency: verify the request's bearer token as a Firebase ID token.

    Returns the decoded payload produced by ``verify_firebase_token_string``.
    Raises a 401 ``HTTPException`` when the request carries no bearer
    credentials.
    """
    if credentials:
        raw_token = credentials.credentials
        return await verify_firebase_token_string(raw_token)
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")


async def _load_user_by_firebase_uid(session, firebase_uid: str):
    """Return the User row whose firebase_uid matches, or None if there is none."""
    result = await session.execute(select(User).where(User.firebase_uid == firebase_uid))
    return result.scalar_one_or_none()


async def get_current_user(payload: dict = Depends(verify_firebase_token)) -> "User":
    """Get or create the local User for the verified Firebase token payload.

    On first sight of a uid, a User is inserted, populated from the token's
    email / name / picture claims. If two requests race to create the same
    user, the loser's INSERT fails with IntegrityError. This function then
    rolls back and returns the row the winner created, instead of failing
    with a 500.

    Raises:
        HTTPException: 401 if the payload carries neither ``uid`` nor ``sub``.
        sqlalchemy.exc.IntegrityError: if the insert fails and no user with
            this uid exists afterwards (e.g. some other constraint was hit).
    """
    from sqlalchemy.exc import IntegrityError

    firebase_uid = payload.get("uid") or payload.get("sub")
    if not firebase_uid:
        # Never create a user keyed on a NULL uid.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has no user id")
    # "or" (not a .get default) so an explicit null email claim does not
    # crash the .split below.
    email = payload.get("email") or ""

    async with AsyncSessionLocal() as session:
        user = await _load_user_by_firebase_uid(session, firebase_uid)
        if user is not None:
            # NOTE(review): is_active is not checked here; confirm whether
            # inactive users should be rejected.
            return user

        user = User(
            firebase_uid=firebase_uid,
            email=email,
            display_name=payload.get("name", email.split("@")[0]),
            photo_url=payload.get("picture"),
            is_active=True,
        )
        session.add(user)
        try:
            await session.commit()
        except IntegrityError:
            # Presumably a unique constraint on firebase_uid was hit by a
            # concurrent first login (TODO confirm against the User model).
            # Re-read the row the other request created.
            await session.rollback()
            existing = await _load_user_by_firebase_uid(session, firebase_uid)
            if existing is None:
                raise
            return existing
        await session.refresh(user)
        return user