File size: 6,928 Bytes
79b2fcc 67b16c6 e90638c 67b16c6 79b2fcc 67b16c6 e90638c 79b2fcc 67b16c6 e90638c 67b16c6 e90638c 67b16c6 79b2fcc 67b16c6 e90638c 79b2fcc 67b16c6 79b2fcc 67b16c6 79b2fcc 67b16c6 bed4e38 67b16c6 8b35696 67b16c6 8460d28 67b16c6 79b2fcc e90638c 67b16c6 79b2fcc 5af3ab5 79b2fcc e90638c 79b2fcc 67b16c6 79b2fcc e90638c 79b2fcc 67b16c6 79b2fcc 67b16c6 79b2fcc 5de6e8f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | """Authentication routes for HF OAuth.
Handles the OAuth 2.0 authorization code flow with HF as provider.
After successful auth, sets an HttpOnly cookie with the access token.
"""
import logging
import os
import secrets
import time
from urllib.parse import urlencode
import httpx
from dependencies import (
AUTH_ENABLED,
OAUTH_SCOPE_COOKIE,
REQUIRED_OAUTH_SCOPES,
configured_oauth_scopes,
get_current_user,
oauth_scope_fingerprint,
)
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
router = APIRouter(prefix="/auth", tags=["auth"])
logger = logging.getLogger(__name__)
# OAuth configuration from environment
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
OAUTH_SCOPES = configured_oauth_scopes()
# In-memory OAuth state store with expiry (5 min TTL)
_OAUTH_STATE_TTL = 300
oauth_states: dict[str, dict] = {}
def _missing_required_scopes(token_data: dict) -> set[str]:
raw_scopes = token_data.get("scope")
if not isinstance(raw_scopes, str) or not raw_scopes.strip():
logger.debug("OAuth token response omitted a usable scope field")
return set()
granted = set(raw_scopes.split())
return set(REQUIRED_OAUTH_SCOPES) - granted
def _cleanup_expired_states() -> None:
"""Remove expired OAuth states to prevent memory growth."""
now = time.time()
expired = [k for k, v in oauth_states.items() if now > v.get("expires_at", 0)]
for k in expired:
del oauth_states[k]
def get_redirect_uri(request: Request) -> str:
"""Get the OAuth callback redirect URI."""
# In HF Spaces, use the SPACE_HOST if available
space_host = os.environ.get("SPACE_HOST")
if space_host:
return f"https://{space_host}/auth/callback"
# Otherwise construct from request
return str(request.url_for("oauth_callback"))
@router.get("/login")
async def oauth_login(request: Request) -> RedirectResponse:
"""Initiate OAuth login flow."""
if not OAUTH_CLIENT_ID:
raise HTTPException(
status_code=500,
detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
)
# Clean up expired states to prevent memory growth
_cleanup_expired_states()
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
oauth_states[state] = {
"redirect_uri": get_redirect_uri(request),
"expires_at": time.time() + _OAUTH_STATE_TTL,
}
# Build authorization URL. We no longer suggest a default `orgIds` —
# users no longer need to join the ML Agent Explorers org to use the
# app, and HF Jobs are billed per-namespace via credits.
params = {
"client_id": OAUTH_CLIENT_ID,
"redirect_uri": get_redirect_uri(request),
"scope": " ".join(OAUTH_SCOPES),
"response_type": "code",
"state": state,
}
auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"
return RedirectResponse(url=auth_url)
@router.get("/callback")
async def oauth_callback(
request: Request, code: str = "", state: str = ""
) -> RedirectResponse:
"""Handle OAuth callback."""
# Verify state
if state not in oauth_states:
raise HTTPException(status_code=400, detail="Invalid state parameter")
stored_state = oauth_states.pop(state)
redirect_uri = stored_state["redirect_uri"]
if not code:
raise HTTPException(status_code=400, detail="No authorization code provided")
# Exchange code for token
token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
async with httpx.AsyncClient() as client:
try:
response = await client.post(
token_url,
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
},
)
response.raise_for_status()
token_data = response.json()
except httpx.HTTPError as e:
raise HTTPException(status_code=500, detail=f"Token exchange failed: {e}")
# Get user info
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(
status_code=500,
detail="Token exchange succeeded but no access_token was returned.",
)
missing_scopes = _missing_required_scopes(token_data)
if missing_scopes:
raise HTTPException(
status_code=403,
detail=(
"OAuth token is missing required scopes: "
+ ", ".join(sorted(missing_scopes))
),
)
# Fetch user info (optional — failure is not fatal)
async with httpx.AsyncClient() as client:
try:
userinfo_response = await client.get(
f"{OPENID_PROVIDER_URL}/oauth/userinfo",
headers={"Authorization": f"Bearer {access_token}"},
)
userinfo_response.raise_for_status()
except httpx.HTTPError:
pass # user_info not required for auth flow
# Set access token as HttpOnly cookie (not in URL — avoids leaks via
# Referrer headers, browser history, and server logs)
is_production = bool(os.environ.get("SPACE_HOST"))
response = RedirectResponse(url="/", status_code=302)
response.set_cookie(
key="hf_access_token",
value=access_token,
httponly=True,
secure=is_production, # Secure flag only in production (HTTPS)
samesite="lax",
max_age=3600 * 24 * 7, # 7 days
path="/",
)
response.set_cookie(
key=OAUTH_SCOPE_COOKIE,
value=oauth_scope_fingerprint(OAUTH_SCOPES),
httponly=True,
secure=is_production,
samesite="lax",
max_age=3600 * 24 * 7,
path="/",
)
return response
@router.get("/logout")
async def logout() -> RedirectResponse:
"""Log out the user by clearing the auth cookie."""
response = RedirectResponse(url="/")
response.delete_cookie(key="hf_access_token", path="/")
response.delete_cookie(key=OAUTH_SCOPE_COOKIE, path="/")
return response
@router.get("/status")
async def auth_status() -> dict:
"""Check if OAuth is enabled on this instance."""
return {"auth_enabled": AUTH_ENABLED}
@router.get("/me")
async def get_me(user: dict = Depends(get_current_user)) -> dict:
"""Get current user info. Returns the authenticated user or dev user.
Uses the shared auth dependency which handles cookie + Bearer token.
"""
return {key: value for key, value in user.items() if not key.startswith("_")}
|