Spaces:
Running
Running
| """Authentication dependencies for FastAPI routes. | |
| - In dev mode (OAUTH_CLIENT_ID unset and REQUIRE_API_AUTH not enabled): auth is bypassed; the user is derived from HF_TOKEN when possible, otherwise a default "dev" user is returned. | |
| - In production: validates Bearer tokens or cookies against HF OAuth. | |
| """ | |
| import logging | |
| import os | |
| import time | |
| from collections.abc import Iterable | |
| from hashlib import sha256 | |
| from typing import Any | |
| import httpx | |
| from fastapi import HTTPException, Request, status | |
| from openai_compat import V1APIError | |
| from agent.core.hf_tokens import bearer_token_from_header, clean_hf_token | |
| from agent.core.hf_access import fetch_whoami_v2, normalize_hf_user_plan | |
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the OpenID provider, read from the environment with the public
# Hugging Face hub as the fallback.
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")

# Auth is enforced when either of these holds:
#   * HF OAuth is configured (OAUTH_CLIENT_ID is non-empty) -- the web Space;
#   * it is explicitly forced through REQUIRE_API_AUTH (any value other than
#     "", "0" or "false") -- set by the API-only Space image, which has no
#     OAuth app but must never fall back to the dev-mode identity.
AUTH_ENABLED = (
    os.environ.get("OAUTH_CLIENT_ID", "") != ""
    or os.environ.get("REQUIRE_API_AUTH", "") not in ("", "0", "false")
)

# Simple in-memory token cache: token -> (user_info, expiry_time)
_token_cache: dict[str, tuple[dict[str, Any], float]] = {}
TOKEN_CACHE_TTL = 5 * 60  # seconds (5 minutes)

# Identity returned when auth is disabled and no better identity is available.
DEV_USER: dict[str, Any] = {
    "user_id": "dev",
    "username": "dev",
    "authenticated": True,
    "plan": "pro",  # Dev uses the Pro web default model.
}

# Key under which the caller's HF token is stashed inside the user dict.
INTERNAL_HF_TOKEN_KEY = "_hf_token"

# Cookie carrying the fingerprint of the scope set the session was issued for.
OAUTH_SCOPE_COOKIE = "hf_oauth_scope_hash"

# Scopes the application always requests, in declaration order.
REQUIRED_OAUTH_SCOPES: tuple[str, ...] = (
    "openid",
    "profile",
    "read-billing",
    "read-repos",
    "write-repos",
    "contribute-repos",
    "manage-repos",
    "write-collections",
    "inference-api",
    "jobs",
    "write-discussions",
)

# Log the whoami-v2 shape once at DEBUG so we can confirm the production Pro
# signal without hammering the HF API.
_WHOAMI_SHAPE_LOGGED = False
def normalize_oauth_scopes(scopes: Iterable[str]) -> tuple[str, ...]:
    """Strip, de-duplicate and freeze a sequence of OAuth scopes.

    Every item is converted with ``str()`` and stripped of surrounding
    whitespace; items that end up empty are dropped. The first occurrence of
    each scope wins, so the result keeps the input's declaration order.
    """
    stripped = (str(scope).strip() for scope in scopes)
    # dict.fromkeys keeps first-seen insertion order while discarding repeats.
    return tuple(dict.fromkeys(value for value in stripped if value))
def configured_oauth_scopes() -> tuple[str, ...]:
    """Return the scopes this backend should request from HF OAuth.

    The whitespace-separated ``OAUTH_SCOPES`` environment value (through which
    Spaces expose the README ``hf_oauth_scopes``) comes first, followed by
    ``REQUIRED_OAUTH_SCOPES``. The union is normalized, so the local request
    and the Space metadata stay in sync while required scopes are never
    omitted.
    """
    requested = os.environ.get("OAUTH_SCOPES", "").split()
    requested.extend(REQUIRED_OAUTH_SCOPES)
    return normalize_oauth_scopes(requested)
def oauth_scope_fingerprint(scopes: Iterable[str] | None = None) -> str:
    """Return a non-secret fingerprint for an OAuth scope set.

    When ``scopes`` is ``None`` the currently configured scopes are used. The
    scopes are normalized and sorted before hashing, so the fingerprint does
    not depend on order or duplicates. The result is the first 16 hex
    characters of the SHA-256 digest of the space-joined scope list.
    """
    if scopes is None:
        scopes = configured_oauth_scopes()
    canonical = sorted(normalize_oauth_scopes(scopes))
    digest = sha256(" ".join(canonical).encode("utf-8"))
    return digest.hexdigest()[:16]
def _cookie_has_current_oauth_scope_marker(request: Request) -> bool:
    """Tell whether the request's scope-marker cookie matches the current scopes.

    Returns ``True`` only when the ``OAUTH_SCOPE_COOKIE`` cookie is present and
    equals the fingerprint of the currently configured scope set.
    """
    presented = request.cookies.get(OAUTH_SCOPE_COOKIE)
    expected = oauth_scope_fingerprint()
    return presented == expected
async def _validate_token(token: str) -> dict[str, Any] | None:
    """Validate a token against the HF OAuth userinfo endpoint.

    Successful validations are cached in ``_token_cache`` for
    ``TOKEN_CACHE_TTL`` seconds to avoid excessive API calls.

    Args:
        token: Access token sent as the ``Authorization: Bearer`` credential.

    Returns:
        The userinfo payload when the endpoint answers 200 with a JSON
        object; ``None`` on any other status, on an httpx transport error, or
        when the body is not a JSON object.
    """
    now = time.time()

    # Check cache
    cached = _token_cache.get(token)
    if cached is not None:
        user_info, expiry = cached
        if now < expiry:
            return user_info
        # Expired: drop the entry so the token is re-validated below.
        _token_cache.pop(token, None)

    # Validate against HF
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
        except httpx.HTTPError as e:
            logger.warning("Token validation error: %s", e)
            return None

        if response.status_code != 200:
            logger.debug("Token validation failed: status %d", response.status_code)
            return None

        # A 200 with a non-JSON body (e.g. an HTML error page from a proxy)
        # raises ValueError, which is not an httpx.HTTPError; treat it as a
        # failed validation instead of letting it surface as a 500.
        try:
            user_info = response.json()
        except ValueError as e:
            logger.warning("Token validation error: invalid JSON body: %s", e)
            return None

        # Callers use dict methods on the payload; never cache anything else.
        if not isinstance(user_info, dict):
            logger.warning(
                "Token validation error: unexpected userinfo payload type %s",
                type(user_info).__name__,
            )
            return None

        _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL)
        return user_info
def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
    """Translate an HF userinfo payload into the app's user dict.

    ``username`` comes from ``preferred_username`` (``"unknown"`` when the key
    is missing). ``user_id`` is the ``sub`` claim when that key is present,
    otherwise it falls back to the username. ``name`` and ``picture`` are
    copied through and are ``None`` when absent.
    """
    username = user_info.get("preferred_username", "unknown")
    # Only a *missing* "sub" key triggers the fallback; a present value is
    # used as-is.
    user_id = user_info["sub"] if "sub" in user_info else username
    return {
        "user_id": user_id,
        "username": username,
        "name": user_info.get("name"),
        "picture": user_info.get("picture"),
        "authenticated": True,
    }
def _normalize_user_plan(whoami: Any) -> str:
    """Map a whoami-v2 payload to a plan tier, defaulting to ``"free"``.

    Delegates to ``normalize_hf_user_plan`` and substitutes ``"free"`` whenever
    that helper returns a falsy value.
    """
    plan = normalize_hf_user_plan(whoami)
    return plan if plan else "free"
async def _fetch_user_plan(token: str) -> str:
    """Look up the user's HF plan via /api/whoami-v2.

    Returns the normalized plan tier. When ``fetch_whoami_v2`` yields ``None``
    (presumably a non-200 answer or a network error; that is decided inside
    the helper) or the payload cannot be mapped to a plan, the result
    collapses to 'free' — a safe default; we'd rather avoid selecting the Pro
    default on bad data.

    The first non-``None`` payload seen by the process has its shape logged
    once at DEBUG level.
    """
    global _WHOAMI_SHAPE_LOGGED

    payload = await fetch_whoami_v2(token)
    if payload is None:
        return "free"

    if not _WHOAMI_SHAPE_LOGGED:
        # Flip the flag first so the shape is reported at most once.
        _WHOAMI_SHAPE_LOGGED = True
        if isinstance(payload, dict):
            shape = sorted(payload.keys())
            is_pro = payload.get("isPro")
        else:
            shape = type(payload).__name__
            is_pro = None
        logger.debug(
            "whoami-v2 payload keys: %s (sample values: isPro=%r)",
            shape,
            is_pro,
        )

    return _normalize_user_plan(payload)
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
    """Turn a token into a full user dict, or ``None`` if it does not validate.

    The user dict built from the userinfo payload is extended with the user's
    plan and, under ``INTERNAL_HF_TOKEN_KEY``, the cleaned token.
    """
    info = await _validate_token(token)
    if info is None:
        return None

    # Dict displays evaluate left to right: profile, then plan, then token.
    return {
        **_user_from_info(info),
        "plan": await _fetch_user_plan(token),
        INTERNAL_HF_TOKEN_KEY: clean_hf_token(token),
    }
async def _dev_user_from_env() -> dict[str, Any]:
    """Use HF_TOKEN as the dev identity when available.

    Local dev often runs without OAuth, but session trace uploads still need a
    real HF namespace. Deriving the dev user from HF_TOKEN keeps local uploads
    pointed at the token owner's dataset instead of dev/ml-intern-sessions.

    Returns:
        A copy of ``DEV_USER`` when ``HF_TOKEN`` is empty, when whoami-v2 does
        not return a dict, or when no username can be found in it. Otherwise a
        user dict for the token owner, carrying the plan and the cleaned
        token under ``INTERNAL_HF_TOKEN_KEY``.
    """
    token = clean_hf_token(os.environ.get("HF_TOKEN", ""))
    if not token:
        return dict(DEV_USER)

    whoami = await fetch_whoami_v2(token)
    if not isinstance(whoami, dict):
        return dict(DEV_USER)

    # First non-empty string among the candidate username fields wins.
    candidates = (whoami.get(key) for key in ("name", "user", "preferred_username"))
    username = next(
        (value for value in candidates if isinstance(value, str) and value),
        None,
    )
    if username is None:
        return dict(DEV_USER)

    return {
        "user_id": username,
        "username": username,
        "authenticated": True,
        # Derive the plan from the payload already fetched above rather than
        # issuing a second whoami-v2 request for the same token on every call.
        "plan": _normalize_user_plan(whoami),
        INTERNAL_HF_TOKEN_KEY: token,
    }
def _unauthorized_error(detail: str) -> HTTPException:
    """Build the 401 response raised by ``get_current_user`` rejections."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(request: Request) -> dict[str, Any]:
    """FastAPI dependency: resolve the user behind the incoming request.

    Credentials are tried in this order:
    1. ``Authorization: Bearer <token>`` header
    2. ``hf_access_token`` cookie

    When ``AUTH_ENABLED`` is false (dev mode) the HF_TOKEN-derived dev user is
    returned without looking at the request.

    Raises:
        HTTPException: 401 when the cookie lacks the current scope marker, or
            when no credential yields a valid user.
    """
    if not AUTH_ENABLED:
        return await _dev_user_from_env()

    # Bearer callers manage token lifecycle themselves; only browser cookie
    # auth is forced through the scope-freshness marker below.
    bearer = bearer_token_from_header(request.headers.get("Authorization", ""))
    if bearer:
        user = await _extract_user_from_token(bearer)
        if user:
            return user

    # An invalid Bearer token falls through to the cookie.
    cookie_token = request.cookies.get("hf_access_token")
    if cookie_token:
        if not _cookie_has_current_oauth_scope_marker(request):
            logger.info(
                "Rejecting stale HF OAuth cookie; current scopes require refresh."
            )
            raise _unauthorized_error(
                "Authentication scopes changed. Please log in again."
            )
        user = await _extract_user_from_token(cookie_token)
        if user:
            return user

    raise _unauthorized_error("Not authenticated. Please log in via /auth/login.")
# ── /v1 developer-API auth ───────────────────────────────────────────
# The /v1 surface accepts plain `hf_...` user access tokens (not just OAuth
# access tokens), validated against /api/whoami-v2 which accepts both kinds.

# In-memory cache for /v1 callers: token -> (user dict, expiry timestamp).
_api_token_cache: dict[str, tuple[dict[str, Any], float]] = {}
def _api_auth_error(message: str) -> V1APIError:
    """Build the 401 ``V1APIError`` used for every /v1 authentication failure.

    The error always carries code ``invalid_api_key`` and type
    ``authentication_error``; only the human-readable message varies.
    """
    unauthorized = 401
    error = V1APIError(
        unauthorized,
        message,
        code="invalid_api_key",
        error_type="authentication_error",
    )
    return error
async def get_api_user(request: Request) -> dict[str, Any]:
    """FastAPI dependency for /v1 routes: Bearer-token-only authentication.

    Validates the token via whoami-v2 so plain user access tokens work. No
    cookie path — developer API callers manage tokens themselves. In dev mode
    (AUTH_ENABLED=False) falls back to the HF_TOKEN-derived dev identity.

    Successful lookups are cached in ``_api_token_cache`` for
    ``TOKEN_CACHE_TTL`` seconds; callers always receive a fresh copy of the
    user dict, so mutating it does not affect the cache.

    Raises:
        V1APIError: 401 when the header is missing, the token does not
            validate, the token's ``type`` is neither absent nor ``"user"``,
            or no username can be resolved.
    """
    if not AUTH_ENABLED:
        return await _dev_user_from_env()

    token = bearer_token_from_header(request.headers.get("Authorization", ""))
    if not token:
        raise _api_auth_error(
            "Missing Authorization header. Pass your Hugging Face token as "
            "'Authorization: Bearer hf_...'."
        )

    now = time.time()
    cached = _api_token_cache.get(token)
    if cached is not None:
        cached_user, expiry = cached
        if now < expiry:
            return dict(cached_user)
        # Expired: evict now, as _validate_token does, so an entry for a token
        # that no longer validates does not stay in memory indefinitely.
        _api_token_cache.pop(token, None)

    whoami = await fetch_whoami_v2(token)
    if not isinstance(whoami, dict):
        raise _api_auth_error(
            "Invalid Hugging Face token (whoami-v2 validation failed)."
        )
    if whoami.get("type") not in (None, "user"):
        raise _api_auth_error(
            "Organization tokens are not supported; use a user access token."
        )

    # First non-empty string among the candidate username fields wins.
    username = None
    for key in ("name", "user", "preferred_username"):
        value = whoami.get(key)
        if isinstance(value, str) and value:
            username = value
            break
    if not username:
        raise _api_auth_error("Could not resolve a username for this token.")

    user: dict[str, Any] = {
        "user_id": str(whoami.get("id") or username),
        "username": username,
        "name": whoami.get("fullname"),
        "authenticated": True,
        "plan": normalize_hf_user_plan(whoami) or "free",
        INTERNAL_HF_TOKEN_KEY: clean_hf_token(token),
    }
    _api_token_cache[token] = (user, now + TOKEN_CACHE_TTL)
    return dict(user)