| | """Authentication dependencies for FastAPI routes.
|
| |
|
| | Provides auth validation for both REST and WebSocket endpoints.
|
| | - In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user.
|
| | - In production: validates Bearer tokens or cookies against HF OAuth.
|
| | """
|
| |
|
| | import logging
|
| | import os
|
| | import time
|
| | from typing import Any
|
| |
|
| | import httpx
|
| | from fastapi import HTTPException, Request, WebSocket, status
|
| |
|
logger = logging.getLogger(__name__)

# Base URL of the OpenID provider whose /oauth/userinfo endpoint validates tokens.
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")
# Authentication is enforced only when a non-empty OAUTH_CLIENT_ID is configured.
AUTH_ENABLED = os.getenv("OAUTH_CLIENT_ID", "") != ""

# Validated-token cache: token -> (userinfo payload, expiry time in seconds).
_token_cache: dict[str, tuple[dict[str, Any], float]] = {}
# Lifetime of a cached validation result, in seconds.
TOKEN_CACHE_TTL = 5 * 60

# Identity handed out when AUTH_ENABLED is False (dev mode).
DEV_USER: dict[str, Any] = dict(
    user_id="dev",
    username="dev",
    authenticated=True,
)
|
| |
|
| |
|
# Upper bound on cached tokens. Without it, tokens that are validated once and
# never presented again (rotated, one-off) would stay in the cache forever,
# because expired entries were only dropped when the same token was re-used.
_TOKEN_CACHE_MAX_ENTRIES = 1024


def _store_token_in_cache(token: str, user_info: dict[str, Any], now: float) -> None:
    """Cache *user_info* for *token*, making room first if the cache is full.

    Expired entries are purged first. If the cache is still full, the oldest
    insertions are evicted (dicts preserve insertion order).
    """
    if len(_token_cache) >= _TOKEN_CACHE_MAX_ENTRIES:
        stale = [key for key, (_, expiry) in _token_cache.items() if expiry <= now]
        for key in stale:
            del _token_cache[key]
    while _token_cache and len(_token_cache) >= _TOKEN_CACHE_MAX_ENTRIES:
        del _token_cache[next(iter(_token_cache))]
    _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL)


async def _validate_token(token: str) -> dict[str, Any] | None:
    """Validate a token against the OpenID provider's userinfo endpoint.

    Successful results are cached for TOKEN_CACHE_TTL seconds to avoid
    excessive API calls. Failures are not cached.

    Args:
        token: Raw access token (without any "Bearer " prefix).

    Returns:
        The userinfo JSON object on success. Returns None if the token is
        empty, the provider rejects it (non-200), the request fails, or the
        response body is not a JSON object.
    """
    if not token:
        # Nothing to validate; skip the network round-trip.
        return None

    now = time.time()

    cached = _token_cache.get(token)
    if cached is not None:
        user_info, expiry = cached
        if now < expiry:
            return user_info
        _token_cache.pop(token, None)

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
    except httpx.HTTPError as e:
        logger.warning("Token validation error: %s", e)
        return None

    if response.status_code != 200:
        logger.debug("Token validation failed: status %d", response.status_code)
        return None

    try:
        user_info = response.json()
    except ValueError:
        # A non-JSON body (e.g. an HTML error page from a proxy) must be
        # treated as a failed validation, not surface as a 500.
        logger.warning("Token validation error: userinfo response is not valid JSON")
        return None
    if not isinstance(user_info, dict):
        # Callers read fields with .get(); anything but an object would crash them.
        logger.warning("Token validation error: userinfo response is not a JSON object")
        return None

    _store_token_in_cache(token, user_info, now)
    return user_info
|
| |
|
| |
|
def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
    """Normalize an OpenID userinfo payload into the app's user dict.

    ``user_id`` prefers the ``sub`` claim. If ``sub`` is missing it falls back
    to ``preferred_username``, and then to "unknown". ``name`` and ``picture``
    are None when absent. The result is always marked authenticated.
    """
    username = user_info.get("preferred_username", "unknown")
    user: dict[str, Any] = {}
    user["user_id"] = user_info.get("sub", username)
    user["username"] = username
    user["name"] = user_info.get("name")
    user["picture"] = user_info.get("picture")
    user["authenticated"] = True
    return user
|
| |
|
| |
|
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
    """Resolve *token* to a normalized user dict.

    Returns None when validation yields no (or an empty) userinfo payload.
    """
    info = await _validate_token(token)
    return _user_from_info(info) if info else None
|
| |
|
| |
|
async def get_current_user(request: Request) -> dict[str, Any]:
    """FastAPI dependency: extract and validate the current user.

    Checks (in order):
    1. Authorization: Bearer <token> header (scheme matched case-insensitively)
    2. hf_access_token cookie

    A Bearer token that fails validation does not end the lookup; the cookie
    is still tried.

    In dev mode (AUTH_ENABLED=False), returns a copy of the default dev user.

    Raises:
        HTTPException: 401 with a WWW-Authenticate: Bearer header when neither
            source yields a valid user.
    """
    if not AUTH_ENABLED:
        # Return a copy so a handler that mutates its user dict cannot alter
        # the shared DEV_USER seen by every later request.
        return dict(DEV_USER)

    # Auth scheme names are case-insensitive (RFC 7235), so accept "bearer" too,
    # and tolerate stray whitespace around the credentials.
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    bearer_token = credentials.strip()
    if scheme.lower() == "bearer" and bearer_token:
        user = await _extract_user_from_token(bearer_token)
        if user:
            return user

    cookie_token = request.cookies.get("hf_access_token")
    if cookie_token:
        user = await _extract_user_from_token(cookie_token)
        if user:
            return user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated. Please log in via /auth/login.",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
| |
|
| |
|
async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None:
    """Extract and validate user from WebSocket connection.

    WebSocket doesn't support custom headers from browser, so we check:
    1. ?token= query parameter
    2. hf_access_token cookie (sent automatically for same-origin)

    The first source that yields a valid user wins. An invalid query token
    falls through to the cookie.

    Returns user dict or None if not authenticated.
    In dev mode, returns a copy of the default dev user.
    """
    if not AUTH_ENABLED:
        # Return a copy so a handler that mutates its user dict cannot alter
        # the shared DEV_USER seen by every later connection.
        return dict(DEV_USER)

    # NOTE(review): tokens passed in the query string may be recorded in
    # server/proxy access logs; consider whether those logs are protected.
    candidates = (
        websocket.query_params.get("token"),
        websocket.cookies.get("hf_access_token"),
    )
    for token in candidates:
        if not token:
            continue
        user = await _extract_user_from_token(token)
        if user:
            return user

    return None
|
| |
|