water3 / dependencies.py
onewayto's picture
Upload 187 files
070daf8 verified
"""Authentication dependencies for FastAPI routes.
Provides auth validation for both REST and WebSocket endpoints.
- In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user.
- In production: validates Bearer tokens (REST), the ?token= query parameter (WebSocket), or the hf_access_token cookie against HF OAuth.
"""
import logging
import os
import time
from typing import Any
import httpx
from fastapi import HTTPException, Request, WebSocket, status
logger = logging.getLogger(__name__)

# Base URL of the OpenID provider. Trailing slashes are stripped so that
# f"{OPENID_PROVIDER_URL}/oauth/userinfo" never yields a "//oauth" path when the
# environment value is configured as e.g. "https://huggingface.co/".
OPENID_PROVIDER_URL = os.environ.get(
    "OPENID_PROVIDER_URL", "https://huggingface.co"
).rstrip("/")

# Auth is enabled whenever OAUTH_CLIENT_ID is set to a non-empty value; an unset
# or empty OAUTH_CLIENT_ID selects dev mode (auth bypassed).
AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", ""))

# Simple in-memory token cache: token -> (user_info, expiry_time)
_token_cache: dict[str, tuple[dict[str, Any], float]] = {}
TOKEN_CACHE_TTL = 300  # seconds (5 minutes)

# User returned by the auth dependencies in dev mode.
DEV_USER: dict[str, Any] = {
    "user_id": "dev",
    "username": "dev",
    "authenticated": True,
}
async def _validate_token(token: str) -> dict[str, Any] | None:
    """Validate a token against the HF OAuth userinfo endpoint.

    Successful lookups are cached in ``_token_cache`` for TOKEN_CACHE_TTL
    seconds to avoid excessive API calls.

    Args:
        token: Raw access token supplied by the client.

    Returns:
        The decoded userinfo JSON object, or None when the provider answers
        with a non-200 status, the request fails at the HTTP layer, or the
        response body is not a JSON object.
    """
    # Monotonic clock: cache expiry must not move when the wall clock is adjusted.
    now = time.monotonic()

    # Check cache
    cached = _token_cache.get(token)
    if cached is not None:
        cached_info, cached_expiry = cached
        if now < cached_expiry:
            return cached_info
        _token_cache.pop(token, None)

    # Validate against HF
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
        except httpx.HTTPError as e:
            logger.warning("Token validation error: %s", e)
            return None

        if response.status_code != 200:
            logger.debug("Token validation failed: status %d", response.status_code)
            return None

        try:
            user_info = response.json()
        except ValueError:
            # A non-JSON body (e.g. an HTML error page) must be treated as a
            # failed validation rather than escaping as an unhandled error.
            logger.warning("Token validation error: userinfo response is not valid JSON")
            return None

    # Callers use dict access (.get) on the result, so anything else is rejected.
    if not isinstance(user_info, dict):
        logger.warning("Token validation error: userinfo response is not a JSON object")
        return None

    # Evict expired entries on every insert. Otherwise tokens that are never
    # presented again would stay in memory forever.
    for stale_token in [t for t, (_, exp) in _token_cache.items() if exp <= now]:
        del _token_cache[stale_token]
    _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL)
    return user_info
def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
"""Build a normalized user dict from HF userinfo response."""
return {
"user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")),
"username": user_info.get("preferred_username", "unknown"),
"name": user_info.get("name"),
"picture": user_info.get("picture"),
"authenticated": True,
}
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
    """Resolve a token to a normalized user dict.

    Returns None when ``_validate_token`` yields a falsy result (rejected
    token, request failure, or an empty payload).
    """
    info = await _validate_token(token)
    return _user_from_info(info) if info else None
async def get_current_user(request: Request) -> dict[str, Any]:
    """FastAPI dependency: extract and validate the current user.

    Checks (in order):
    1. Authorization: Bearer <token> header (scheme matched case-insensitively)
    2. hf_access_token cookie

    In dev mode (AUTH_ENABLED=False), returns a copy of the default dev user.

    Raises:
        HTTPException: 401 when no candidate token validates.
    """
    if not AUTH_ENABLED:
        # Return a copy so a route that mutates its user dict cannot alter
        # the module-level DEV_USER seen by every later request.
        return dict(DEV_USER)

    # Try Authorization header. Per RFC 7235 the auth scheme is
    # case-insensitive, so "bearer" is accepted as well as "Bearer".
    header_token: str | None = None
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    if scheme.lower() == "bearer":
        # An empty credential is skipped rather than sent to the provider.
        header_token = credentials.strip() or None
    if header_token:
        user = await _extract_user_from_token(header_token)
        if user:
            return user

    # Try cookie. If it holds the same token that was just rejected above,
    # skip it instead of validating it a second time.
    cookie_token = request.cookies.get("hf_access_token")
    if cookie_token and cookie_token != header_token:
        user = await _extract_user_from_token(cookie_token)
        if user:
            return user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated. Please log in via /auth/login.",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None:
    """Extract and validate the user from a WebSocket connection.

    Browsers cannot attach custom headers to a WebSocket handshake, so the
    token is looked up in:
    1. the ?token= query parameter
    2. the hf_access_token cookie (sent automatically for same-origin)

    Returns:
        The normalized user dict, or None if no candidate token validates.
        In dev mode (AUTH_ENABLED=False), a copy of the default dev user.
    """
    if not AUTH_ENABLED:
        # Return a copy so a handler that mutates its user dict cannot alter
        # the module-level DEV_USER shared by all connections.
        return dict(DEV_USER)

    # Try query param
    query_token = websocket.query_params.get("token")
    if query_token:
        user = await _extract_user_from_token(query_token)
        if user:
            return user

    # Try cookie (works for same-origin WebSocket). If it holds the same token
    # that was just rejected above, skip it instead of re-validating it.
    cookie_token = websocket.cookies.get("hf_access_token")
    if cookie_token and cookie_token != query_token:
        user = await _extract_user_from_token(cookie_token)
        if user:
            return user

    return None