| """Authentication routes for HF OAuth. |
| |
| Handles the OAuth 2.0 authorization code flow with HF as provider. |
| After successful auth, sets an HttpOnly cookie with the access token. |
| """ |
|
|
| import logging |
| import os |
| import secrets |
| import time |
| from urllib.parse import urlencode |
|
|
| import httpx |
| from dependencies import ( |
| AUTH_ENABLED, |
| OAUTH_SCOPE_COOKIE, |
| REQUIRED_OAUTH_SCOPES, |
| configured_oauth_scopes, |
| get_current_user, |
| oauth_scope_fingerprint, |
| ) |
| from fastapi import APIRouter, Depends, HTTPException, Request |
| from fastapi.responses import RedirectResponse |
|
|
# Module-level logger and the router that mounts every endpoint under /auth.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])

# OAuth client configuration, read once from the environment at import time.
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "")
OAUTH_SCOPES = configured_oauth_scopes()

# Pending login attempts, keyed by the random ``state`` value. Each entry holds
# the redirect URI used for that attempt and an ``expires_at`` timestamp
# (seconds since the epoch). The store lives in process memory only.
_OAUTH_STATE_TTL = 300  # seconds a pending state stays valid
oauth_states: dict[str, dict] = {}
|
|
|
|
def _missing_required_scopes(token_data: dict) -> set[str]:
    """Return the required OAuth scopes that a token response did not grant.

    The ``scope`` entry of ``token_data`` is read as a whitespace-separated
    string. When it is absent, not a string, or blank, nothing is reported as
    missing; a debug message is logged and an empty set is returned.
    """
    scope_field = token_data.get("scope")
    has_scopes = isinstance(scope_field, str) and bool(scope_field.strip())
    if has_scopes:
        return set(REQUIRED_OAUTH_SCOPES).difference(scope_field.split())
    logger.debug("OAuth token response omitted a usable scope field")
    return set()
|
|
|
|
def _cleanup_expired_states() -> None:
    """Drop every pending OAuth state whose expiry time has passed.

    Entries without an ``expires_at`` value are treated as expiring at 0.
    The shared ``oauth_states`` dict is mutated in place.
    """
    cutoff = time.time()
    # Iterate over a snapshot: a dict must not change size while being iterated.
    for state_key, entry in list(oauth_states.items()):
        if entry.get("expires_at", 0) < cutoff:
            oauth_states.pop(state_key)
|
|
|
|
def get_redirect_uri(request: Request) -> str:
    """Build the absolute URL of the OAuth callback endpoint.

    When the ``SPACE_HOST`` environment variable is set, the URL is formed
    from it over HTTPS. Otherwise it is resolved from the incoming request
    via the ``oauth_callback`` route name.
    """
    host = os.environ.get("SPACE_HOST")
    if not host:
        return str(request.url_for("oauth_callback"))
    return "https://" + host + "/auth/callback"
|
|
|
|
@router.get("/login")
async def oauth_login(request: Request) -> RedirectResponse:
    """Initiate the OAuth login flow.

    Generates a random ``state`` value, records it together with the redirect
    URI and an expiry timestamp, and redirects the browser to the provider's
    authorization endpoint.

    Args:
        request: Incoming request, used to resolve the callback URL.

    Returns:
        A redirect to ``{OPENID_PROVIDER_URL}/oauth/authorize``.

    Raises:
        HTTPException: 500 when ``OAUTH_CLIENT_ID`` is not configured.
    """
    if not OAUTH_CLIENT_ID:
        raise HTTPException(
            status_code=500,
            detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
        )

    # Purge stale entries on every login so the in-memory store stays bounded
    # by the number of logins started within one TTL window.
    _cleanup_expired_states()

    # Resolve the redirect URI exactly once: the value stored here is replayed
    # during the token exchange and must match the one sent to the provider.
    redirect_uri = get_redirect_uri(request)

    state = secrets.token_urlsafe(32)
    oauth_states[state] = {
        "redirect_uri": redirect_uri,
        "expires_at": time.time() + _OAUTH_STATE_TTL,
    }

    params = {
        "client_id": OAUTH_CLIENT_ID,
        "redirect_uri": redirect_uri,
        "scope": " ".join(OAUTH_SCOPES),
        "response_type": "code",
        "state": state,
    }
    auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"

    return RedirectResponse(url=auth_url)
|
|
|
|
@router.get("/callback")
async def oauth_callback(
    request: Request, code: str = "", state: str = ""
) -> RedirectResponse:
    """Handle the OAuth callback: validate state, exchange the code, set cookies.

    Args:
        request: Incoming request (not read by this handler).
        code: Authorization code supplied by the provider.
        state: Opaque value issued by ``/auth/login``.

    Returns:
        A 302 redirect to ``/`` carrying the ``hf_access_token`` cookie and
        the OAuth scope fingerprint cookie.

    Raises:
        HTTPException: 400 for an unknown or expired ``state`` or a missing
            ``code``; 500 when the token exchange fails, returns invalid
            JSON, or yields no access token; 403 when the token lacks
            required scopes.
    """
    # pop() makes each state single-use. The expiry is checked here as well,
    # because stale entries are otherwise only purged when /login is called.
    stored_state = oauth_states.pop(state, None)
    if stored_state is None or time.time() > stored_state.get("expires_at", 0):
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    redirect_uri = stored_state["redirect_uri"]

    if not code:
        raise HTTPException(status_code=400, detail="No authorization code provided")

    # Exchange the authorization code for an access token.
    token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
    async with httpx.AsyncClient() as client:
        try:
            token_response = await client.post(
                token_url,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": redirect_uri,
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                },
            )
            token_response.raise_for_status()
            token_data = token_response.json()
        except httpx.HTTPError as e:
            raise HTTPException(
                status_code=500, detail=f"Token exchange failed: {e}"
            ) from e
        except ValueError as e:
            # A body that is not valid JSON raises ValueError, not HTTPError.
            raise HTTPException(
                status_code=500,
                detail="Token exchange failed: provider returned invalid JSON",
            ) from e

    # Guard against a JSON body that is not an object (e.g. a list or string).
    access_token = (
        token_data.get("access_token") if isinstance(token_data, dict) else None
    )
    if not access_token:
        raise HTTPException(
            status_code=500,
            detail="Token exchange succeeded but no access_token was returned.",
        )
    missing_scopes = _missing_required_scopes(token_data)
    if missing_scopes:
        raise HTTPException(
            status_code=403,
            detail=(
                "OAuth token is missing required scopes: "
                + ", ".join(sorted(missing_scopes))
            ),
        )

    # Best-effort userinfo request: its body is not used, and a failure does
    # not block the login. The failure is logged rather than silently dropped.
    # NOTE(review): presumably a sanity check of the new token - confirm
    # whether this request is still needed.
    async with httpx.AsyncClient() as client:
        try:
            userinfo_response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
            )
            userinfo_response.raise_for_status()
        except httpx.HTTPError as e:
            logger.warning("OAuth userinfo request failed: %s", e)

    # The Secure flag is only set when SPACE_HOST is configured.
    is_production = bool(os.environ.get("SPACE_HOST"))
    redirect = RedirectResponse(url="/", status_code=302)
    session_cookies = (
        ("hf_access_token", access_token),
        (OAUTH_SCOPE_COOKIE, oauth_scope_fingerprint(OAUTH_SCOPES)),
    )
    for cookie_key, cookie_value in session_cookies:
        redirect.set_cookie(
            key=cookie_key,
            value=cookie_value,
            httponly=True,
            secure=is_production,
            samesite="lax",
            max_age=3600 * 24 * 7,  # one week
            path="/",
        )
    return redirect
|
|
|
|
@router.get("/logout")
async def logout() -> RedirectResponse:
    """Clear both authentication cookies and redirect the browser to ``/``."""
    redirect = RedirectResponse(url="/")
    for cookie_name in ("hf_access_token", OAUTH_SCOPE_COOKIE):
        redirect.delete_cookie(key=cookie_name, path="/")
    return redirect
|
|
|
|
@router.get("/status")
async def auth_status() -> dict:
    """Report the value of ``AUTH_ENABLED`` for this instance."""
    payload = {"auth_enabled": AUTH_ENABLED}
    return payload
|
|
|
|
@router.get("/me")
async def get_me(user: dict = Depends(get_current_user)) -> dict:
    """Return the current user's public fields.

    The user dict is supplied by the shared ``get_current_user`` dependency
    (defined in ``dependencies``; how it authenticates is not visible here).
    Keys that start with an underscore are stripped from the response.
    """
    public_fields = {}
    for field, value in user.items():
        if field.startswith("_"):
            continue
        public_fields[field] = value
    return public_fields
|
|